#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 12 13:00:18 2017

@author: luohao
"""

import numpy as np
import os
from keras import optimizers
from keras.utils import np_utils, generic_utils
from keras.models import Sequential, Model, load_model
from keras.layers import Dropout, Flatten, Dense, Input, GlobalAveragePooling2D, GlobalMaxPooling2D
from keras.applications.resnet50 import ResNet50
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.inception_v3 import InceptionV3
from keras.applications.imagenet_utils import preprocess_input
from keras import backend as K
from keras.layers.core import Lambda
from scipy.misc import imread, imresize
from sklearn.preprocessing import normalize
from keras.preprocessing.image import ImageDataGenerator
from keras.initializers import RandomNormal
import tensorflow as tf

import dao
from datasets import load_dataset
from inception_resnet_v1 import InceptionResNetV1
from models.aug import aug_nhw3
import numpy.linalg as la
from six.moves import xrange
import facenet

IMAGE_SIZE = 224  # square input size fed to the ResNet50 backbone below

# Scraped product-description attribute headers that carry no label for this
# model (store metadata, fabric, care instructions, ...) and are skipped when
# building attribute vectors.
# NOTE(review): the list contains duplicates (e.g. '尺码', '品牌', '流行元素') —
# harmless for membership tests, but a set would be cleaner and faster.
not_header = ['商品名称', '店铺', '商品毛重', '商品编号', '上市时间', '货号', '商品产地',
              '表盘颜色', '适用年龄', '尺码', '适用年龄', '销售渠道类型', '材质成分', '品牌', '品牌名称',
              '上市年份季节', '年份季节', '配件/备注', '产地', '洗涤说明', '面料成分', '材质', '大码女装分类', '分类', '适用季节',
              '面料', '成分含量', '质地', '尺寸', '服装款式细节', '材质1', '自定义', '材质3', '深灰色', '黑色', '深蓝', '深蓝2.0',
              '温馨提示', '藏青色', '灰色材质', '成分', '材质2', '蓝/黑色材质', '白色材质', '纤维成份', '里料', '厚薄', '厚度',
              '上市年份/季节', '年份/季节', '街头', '甜美', '主要颜色', '规格', '罩杯材质', '中老年女装分类', '穿着方式', '尺码', '品牌',
              '通勤', '流行元素', '中老年风格', '服饰工艺', '女装质地', '适用场景', '网纱', '组合形式', '毛线粗细', '花朵', '流行元素',
              '工艺', '面料分类', '风格', '中老年女装图案', '适用对象', '克重', '组合规格', '里料分类', '主要材质', '图案',
              '填充物', '牛仔面料盎司', '功能', '设计裁剪', '人群', '图案文化', '流行元素/工艺', '流行元素/工艺', '弹力', '2.0材质',
              '适用场合', '重量', '类型', '弹性', '款式', '选购热点', '适用人群', '适合人群', '面料成份'
              ]

# Catalog IDs covered by the model.  Original comment named four categories
# [tops, skirts/dresses, trousers, specialty wear] for only three IDs —
# TODO(review): confirm the ID-to-name mapping.
catalog = [93, 94, 95]  # ['上装','裙装','裤装','特色服饰']
# Attribute vocabularies: index positions define the multi-hot encodings
# built in get_img_info().  Values are the raw Chinese attribute strings
# found in t_product.description and must stay byte-identical.
sleeve = ['长袖', '短袖', '无袖', '五分袖', '七分袖', '九分袖']  # sleeve length
# Full collar-type vocabulary (commented out); only the reduced set below is used:
# collar = ['V领', '不规则领', '圆领', '低圆领', '其它', '高领', '半高领', '翻领', '连帽', '围巾领', 'A字型', '立领', '半高圆领',
#           '一字领', '方领', '旗袍领', 'POLO领', '堆堆领', '无领', '椭圆领', '半开领', 'U型领', '八字领', '高圆领', '低领', '常规领',
#           '娃娃领', '鸡心领', '平领', '双层领', '披肩领', '荡领', '橄榄领', '海军领', '大尖领', '系带领', '无', '荷叶领', '挂脖']  # 领型
collar = ['圆领', '其它', 'V领', '高领', '娃娃领', '立领', '方领', '翻领', '海军领', '半高领', 'POLO领', '连帽', '一字领', '半高圆领', '荷叶领']
coat_length = ['超短款', '短款', '常规款', '中长款', '长款', '不对称衣长']  # garment (coat) length
dress_length = ['超短裙', '短裙', '中裙', '中长裙', '长裙']  # skirt/dress length
outseam = ['短裤', '长裤', '热裤', '五分裤', '七分裤', '九分裤']  # trouser length
waistline = ['低腰', '中腰', '高腰', '松紧腰', '调节腰']  # waist type


def get_good_ids(base_dir):
    """Return the names of all immediate subdirectories of *base_dir*."""
    good_ids = []
    for entry in os.listdir(base_dir):
        if os.path.isdir(os.path.join(base_dir, entry)):
            good_ids.append(entry)
    return good_ids


class ImageClass():
    """Container tying one product (good) id to its image paths and labels."""

    def __init__(self, good_id, image_paths, catalog, attr):
        self.good_id = good_id          # source product id (string)
        self.image_paths = image_paths  # comment + description image files
        self.catalog = catalog          # one-hot catalog vector
        self.attr = attr                # (1, n_attrs) multi-hot attribute row

    def __str__(self):
        # Trailing comma kept to match the original log format.
        return '%s, %d images,' % (self.good_id, len(self.image_paths))

    def __len__(self):
        return len(self.image_paths)


def get_img_info(base_dir):
    """Collect per-product image paths and label vectors.

    Scans ``base_dir/comment/<good_id>`` and ``base_dir/desc/<good_id>`` for
    image files, looks each product up in the ``t_product`` table, and encodes
    its catalog as a one-hot vector and its clothing attributes (sleeve,
    collar, garment/skirt/trouser length, waist) as one concatenated
    multi-hot row vector.

    Returns a list of ImageClass instances.  Products without a DB row, or
    whose attribute dict fails to parse, are skipped (best effort).

    Bug fix: ``key`` is now pre-bound before the ``try`` block — previously,
    if ``eval(result['description'])`` itself raised, the except handler
    crashed with a NameError while building its log message.
    """
    images = []
    good_ids = os.listdir(os.path.join(base_dir, 'comment'))
    for good_id in good_ids:
        img_attr = []
        comment_dir = os.path.join(base_dir, 'comment', good_id)
        # NOTE(review): assumes a matching 'desc' folder exists for every
        # 'comment' folder; os.listdir below raises otherwise — confirm.
        desc_dir = os.path.join(base_dir, 'desc', good_id)

        # One multi-hot slot per attribute vocabulary (module-level lists).
        sleeve_categorical = np.zeros(len(sleeve))
        collar_categorical = np.zeros(len(collar))
        coat_length_categorical = np.zeros(len(coat_length))
        dress_length_categorical = np.zeros(len(dress_length))
        outseam_categorical = np.zeros(len(outseam))
        waistline_categorical = np.zeros(len(waistline))

        # Image paths: buyer-comment photos first, then seller 'desc' photos.
        image_paths = [os.path.join(comment_dir, name) for name in os.listdir(comment_dir)]
        image_paths += [os.path.join(desc_dir, name) for name in os.listdir(desc_dir)]

        str_sql = "select catalogID,description from t_product where sourceProductID='%s'" % good_id
        result = dao.select_one(str_sql)
        if result:
            # Product catalog (93/94/95) as a one-hot row vector.
            catalog_id = int(result['catalogID'])
            catalog_categorical = np_utils.to_categorical(catalog.index(catalog_id), len(catalog))

            # Product attributes: the description column stores a Python dict
            # literal.  SECURITY: eval() of DB content is unsafe if the table
            # can ever hold untrusted data — consider ast.literal_eval.
            key = ''  # pre-bind so the except handler can always log it
            try:
                attrs = eval(result['description'])
                for key in attrs.keys():
                    if key in not_header:
                        continue
                    attr = attrs.get(key)
                    if key == '袖长':  # sleeve length
                        if attr == '中袖':
                            attr = '五分袖'  # normalize mid-sleeve to half-sleeve
                        if not catalog_id == 95:  # trousers (95) have no sleeves
                            sleeve_categorical[sleeve.index(attr)] = 1.
                    elif key == '领型':  # collar type
                        collar_categorical[collar.index(attr)] = 1.
                    elif key == '衣长':  # garment length
                        coat_length_categorical[coat_length.index(attr)] = 1.
                    elif key == '裙长':  # skirt length
                        dress_length_categorical[dress_length.index(attr)] = 1.
                    elif key == '裤长':  # trouser length, only for catalog 95
                        if catalog_id == 95:
                            outseam_categorical[outseam.index(attr)] = 1.
                    elif key == '腰型':  # waist type, only for catalog 95
                        if catalog_id == 95:
                            waistline_categorical[waistline.index(attr)] = 1.
            except Exception as e:
                # Best effort: log the offending product and skip it.
                print('catalog_id:' + str(catalog_id) + ',attr eval error:' + key + ',good_id:' + good_id)
                print(e)
                continue
            # Concatenate all attribute groups into one flat multi-hot vector.
            img_attr.extend(sleeve_categorical)
            img_attr.extend(collar_categorical)
            img_attr.extend(coat_length_categorical)
            img_attr.extend(dress_length_categorical)
            img_attr.extend(outseam_categorical)
            img_attr.extend(waistline_categorical)
            images.append(
                ImageClass(good_id, image_paths, catalog_categorical, np.array(img_attr).reshape(1, len(img_attr))))
    return images


def triplet_loss(y_true, y_pred):
    """Triplet margin loss over L2-normalised embeddings.

    The batch is expected to be laid out as [anchors | positives | negatives],
    each segment of size PN.  *y_true* is ignored — it exists only to satisfy
    the Keras loss signature.
    NOTE(review): relies on the module-level global PN, which is assigned in
    the __main__ script below — confirm it is set before compile().
    """
    embeddings = K.l2_normalize(y_pred, axis=1)
    n = PN
    anchors = embeddings[:n, :]
    positives = embeddings[n:2 * n, :]
    negatives = embeddings[2 * n:3 * n, :]
    # Euclidean distances anchor->positive and anchor->negative.
    pos_dist = K.sqrt(K.sum(K.square(anchors - positives), axis=1, keepdims=True))
    neg_dist = K.sqrt(K.sum(K.square(anchors - negatives), axis=1, keepdims=True))
    margin = 0.2
    return K.mean(K.maximum(0.0, pos_dist - neg_dist + margin))


def get_triplet_data(images_info, PN):
    """Assemble one triplet batch of PN anchors, positives and negatives.

    Rows [0, PN) hold anchors, [PN, 2*PN) positives (the first 'desc' image
    of the same product after shuffling, falling back to a copy of the
    anchor), and [2*PN, 3*PN) negatives (the first 'desc' image of a
    different, randomly drawn product).

    Returns (images, catalog_labels, attr_labels) with matching row order.
    """
    nrof_classes = len(images_info)
    shuffled_classes = np.arange(nrof_classes)
    np.random.shuffle(shuffled_classes)

    batch = 3 * PN
    images = np.zeros((batch, IMAGE_SIZE, IMAGE_SIZE, 3))
    # NOTE(review): catalog width hard-coded to 3 (= len(catalog) here).
    catalog_labels = np.zeros((batch, 3))
    attr_labels = np.zeros((batch, len(images_info[0].attr[0])))

    for i in range(PN):
        picked = shuffled_classes[i]
        info = images_info[picked]
        order = np.arange(len(info))
        np.random.shuffle(order)

        # Anchor: a random image of the product.
        images[i] = get_img_array(info.image_paths[order[0]])
        attr_labels[i] = info.attr[0]
        catalog_labels[i] = info.catalog

        # Positive: default to a duplicate of the anchor, then prefer the
        # first 'desc' image among the remaining shuffled images.
        images[i + PN] = get_img_array(info.image_paths[order[0]])
        for k in range(1, len(info)):
            candidate = info.image_paths[order[k]]
            if 'desc' in candidate:
                images[i + PN] = get_img_array(candidate)
                break
        attr_labels[i + PN] = info.attr[0]
        catalog_labels[i + PN] = info.catalog

        # Negative: draw from [0, nrof_classes-1) and bump by one on a
        # collision, so the anchor's own class index is never used.
        neg_index = np.random.randint(nrof_classes - 1)
        if neg_index == picked:
            neg_index += 1
        neg_info = images_info[neg_index]
        # If the negative product has no 'desc' image, its row stays zeros.
        for candidate in neg_info.image_paths:
            if 'desc' in candidate:
                images[i + 2 * PN] = get_img_array(candidate)
                break
        attr_labels[i + 2 * PN] = neg_info.attr[0]
        catalog_labels[i + 2 * PN] = neg_info.catalog
    return images, catalog_labels, attr_labels


def get_random_batch(images_info, batch_num=50):
    """Return *batch_num* distinct entries drawn at random from *images_info*.

    When the list has at most *batch_num* entries it is returned unchanged.

    Bug fix: the original shuffled an index array but then indexed with the
    plain loop counter, so it always returned the first *batch_num* entries
    in order and the shuffle had no effect.
    """
    nrof_classes = len(images_info)
    if nrof_classes <= batch_num:
        return images_info
    class_indices = np.arange(nrof_classes)
    np.random.shuffle(class_indices)
    # Index with the shuffled positions so the batch is actually random.
    return [images_info[class_indices[i]] for i in range(batch_num)]


def get_img_array(img_path):
    """Load *img_path* and return it as an IMAGE_SIZE x IMAGE_SIZE x 3 array.

    Bug fix: the original called ``imresize(IMAGE_SIZE, IMAGE_SIZE, 3)``,
    passing the target sizes where the image array belongs, so any image that
    was not already 224x224x3 raised instead of being resized.

    NOTE(review): a grayscale input would come back 2-D from imread and stay
    2-D after resizing — confirm the dataset is all RGB.
    """
    img = imread(img_path)
    if img.shape != (IMAGE_SIZE, IMAGE_SIZE, 3):
        img = imresize(img, (IMAGE_SIZE, IMAGE_SIZE))
    # Pixel values are left unscaled in [0, 255]; the /255 normalisation was
    # deliberately disabled upstream.
    return img


if __name__ == '__main__':
    images = get_img_info('E:/cdn/datasets/triplet')
    random_batch = 50
    SN = 3   # segments per triplet batch: anchor / positive / negative
    PN = 20  # triplets per batch (also read by triplet_loss via global)
    identity_num = len(images[0].attr[0])

    # Pre-trained ResNet50 backbone (ImageNet weights), classifier removed.
    base_model = ResNet50(weights='imagenet', include_top=False,
                          input_tensor=Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3)))
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    fc1 = Dropout(0.5)(x)
    # 512-d embedding head trained with the triplet loss.
    feature = Dense(512, use_bias=False, name='Bottleneck')(fc1)
    x = Dense(1024, activation='relu')(x)  # new FC layer, random init
    # Attribute head.  NOTE(review): softmax + binary_crossentropy on a
    # multi-hot target is unusual (sigmoid is conventional) — kept as-is.
    preds = Dense(identity_num, activation='softmax', name='fc8')(x)

    class_triplet_model = Model(inputs=base_model.input, outputs=[preds, feature])
    # Freeze the whole backbone; only the new heads are trained.
    for layer in base_model.layers:
        layer.trainable = False

    lr = 0.001
    adam = optimizers.Adam(lr=lr)
    class_triplet_model.compile(optimizer=adam, loss=['binary_crossentropy', triplet_loss],
                                loss_weights=[1.0, 1.0], metrics={'fc8': 'accuracy'})

    datagen = ImageDataGenerator(
        featurewise_center=True,
        featurewise_std_normalization=True,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True)

    # Training loop: outer loop re-samples a random subset of products,
    # inner loop builds triplet batches from it.
    for i in range(30):
        print('random_batch_i:' + str(i))
        batch_images = get_random_batch(images, random_batch)
        for j in range(300):  # some products have up to ~400-500 images
            print('i_' + str(i) + '_sub_batch_j:' + str(j))
            # Bug fix: get_triplet_data returns three values
            # (images, catalog_labels, attr_labels); the original unpacked
            # only two and crashed with a ValueError.  The attribute labels
            # feed the 'fc8' head; catalog labels are unused here.
            train_img, _catalog_label, train_label = get_triplet_data(batch_images, PN)
            datagen.fit(train_img, augment=True)
            epoch = 0
            # shuffle=False everywhere: triplet_loss depends on the
            # [anchor | positive | negative] row layout being preserved.
            for x_train, y_train in datagen.flow(train_img, train_label, batch_size=PN * SN, shuffle=False):
                print('i_' + str(i) + '_sub_batch_j_' + str(j) + '_epoch:' + str(epoch))
                # Dummy all-ones target for the embedding output:
                # triplet_loss ignores y_true, but Keras requires one
                # target per model output.
                class_triplet_model.fit(x_train,
                                        y=[y_train, np.ones([PN * SN, 512])],
                                        shuffle=False, epochs=1, batch_size=PN * SN)
                epoch += 1
                if epoch > 2:
                    break
    # Save the trained model.
    model_save = os.path.join('E:/cdn/models', 'triplet_resnet50.h5')
    class_triplet_model.save(model_save)
    class_triplet_model.summary()
